function [deblurred_image, coef, updated_dict, lucy_image] = direct_sparse(params, varargin)
% DIRECT_SPARSE  Non-blind deblurring with a sparse dictionary and a known kernel.
%
% Based on the direct sparse algorithm ("Direct sparse deblurring" by
% Yifei Lou), which solves the deblurring problem with a known kernel.
% The difference from that algorithm is that here the blurry image is first
% deconvolved (Richardson-Lucy) and the result is used to update the
% dictionary with K-SVD before the sparse coding step.
%
% Required fields in PARAMS:
%  --------------------------
%    'blurry'   - blurry image
%    'initdict' - initial dictionary, one s_c*s_c atom per column
%    'kernel'   - known blur kernel (only size(kernel,1) is used to derive
%                 the blurred patch size, so a square kernel is assumed)
%    'rowskip'  - row stride between extracted patches
%    'colskip'  - column stride between extracted patches
%
% PARAMS is forwarded to ksvd after the fields 'data', 'Tdata', 'initdict',
% 'iternum' and 'memusage' have been set here; any other fields are passed
% through untouched. VARARGIN is accepted but not used.
%
% Outputs:
%  --------
%    deblurred_image - weighted average of the reconstructed patches,
%                      clipped to [0, 1]
%    coef            - coefficients returned by omp2, one column per
%                      blurry patch. The last column belongs to a repeated
%                      copy of the bottom-right patch and is not used in
%                      the reconstruction (kept for backward compatibility).
%    updated_dict    - dictionary returned by ksvd
%    lucy_image      - contrast-adjusted Richardson-Lucy deconvolution
%
% Zhe Hu
% Computer Science Department
% University of California, Merced
% zhu@ucmerced.edu
%
% October 2009

blurry_image = params.blurry;
initial_dict = params.initdict;
kernel = params.kernel;
row_step = params.rowskip;
col_step = params.colskip;

%% Sizes
s_c = sqrt(size(initial_dict, 1));   % side of a sharp dictionary atom
n   = size(initial_dict, 2);         % number of atoms
s_k = size(kernel, 1);               % kernel side
s_b = s_c - s_k + 1;                 % side of an atom after 'valid' convolution

[row_num, col_num] = size(blurry_image);

if s_b < 1
    error('direct_sparse:kernelTooLarge', ...
        'Kernel side (%d) exceeds the dictionary patch side (%g).', s_k, s_c);
end
if row_num < s_b || col_num < s_b
    error('direct_sparse:imageTooSmall', ...
        'Blurry image (%d x %d) is smaller than a blurred patch (%g x %g).', ...
        row_num, col_num, s_b, s_b);
end

%% Patches of the blurry image
% Strided top-left corners plus one extra start aligned with the bottom /
% right border, so that every pixel is covered by at least one patch. The
% border start is appended even when the stride already reaches it.
last_row = row_num - s_b + 1;
last_col = col_num - s_b + 1;
row_starts = [1:row_step:last_row, last_row];
col_starts = [1:col_step:last_col, last_col];

data = collect_patches(blurry_image, s_b, row_starts, col_starts);
num_patches = size(data, 2);

% The original implementation appended the bottom-right patch once more.
% That column is kept so that COEF retains its historical column count.
data(:, end+1) = data(:, end);

%% Update the dictionary
% Deconvolve the blurry image; the weight mask is zero on a border of
% ceil(kernel side / 2) pixels and one elsewhere.
numit  = 20;
dampar = 0.0005;
lim = ceil(size(kernel, 1) / 2);
lucy_weight = zeros(size(blurry_image));
lucy_weight(lim+1:end-lim, lim+1:end-lim) = 1;

deconv_image = deconvlucy(blurry_image, kernel, numit, dampar, lucy_weight);
lucy_image = imadjust(deconv_image, stretchlim(deconv_image), stretchlim(blurry_image));

% Training patches of atom size from the deconvolved image (strided
% positions only, no border-aligned extras).
deconv_data = collect_patches(lucy_image, s_c, ...
    1:row_step:row_num-s_c+1, 1:col_step:col_num-s_c+1);

% K-SVD
control_param = 5;
params.data = deconv_data;
params.Tdata = control_param;
params.initdict = initial_dict;
params.iternum = 30;
params.memusage = 'high';

% Three outputs are requested, as in the original call, so that ksvd sees
% the same nargout; only the dictionary is used.
[updated_dict, unused_gamma, unused_err] = ksvd(params, ''); %#ok<ASGLU,NASGU>

%% Blurred dictionary and cropped ("resized") dictionary
blur_cols = cell(1, n);
crop_cols = cell(1, n);
for k = 1:n
    atom = reshape(updated_dict(:, k), s_c, s_c);
    blurred_atom = conv2(atom, kernel, 'valid');
    cropped_atom = atom(1:s_b, 1:s_b);  % can try some other part of the atom
    blur_cols{k} = blurred_atom(:);
    crop_cols{k} = cropped_atom(:);
end
blurdict   = [blur_cols{:}];
resizedict = [crop_cols{:}];

% Normalize every atom to unit L2 norm. The explicit dimension keeps the
% sums per-column even when the dictionaries have a single row.
bMag = sqrt(sum(blurdict.^2, 1));
rMag = sqrt(sum(resizedict.^2, 1));
blurdict   = bsxfun(@rdivide, blurdict, bMag);
resizedict = bsxfun(@rdivide, resizedict, rMag);

%% Sparse coding of the blurry patches over the blurred dictionary
sparse_num = 20;
G = blurdict' * blurdict;
% NOTE(review): 0.01 is passed as omp2's fourth argument; in OMP-Box that
% is the target representation error - confirm against the installed omp2.
coef = omp2(blurdict' * data, sum(data .* data, 1), G, 0.01, 'maxatoms', sparse_num);

%% Recover the image
% Each patch is rebuilt from the cropped dictionary and accumulated with a
% weight exp((1 - zeronorm(coefficients)) / control_param); the weighted
% sum is divided by the accumulated weights at the end.
recon_sum  = zeros(row_num, col_num);
weight_sum = zeros(row_num, col_num);
patch_idx = 0;
for i = row_starts
    rows = i:i+s_b-1;
    for j = col_starts
        cols = j:j+s_b-1;
        patch_idx = patch_idx + 1;   % same visiting order as in DATA

        gamma = coef(:, patch_idx);
        patch_weight = exp((1 - zeronorm(gamma)) / control_param);
        patch = reshape(full(resizedict * gamma), s_b, s_b);

        recon_sum(rows, cols)  = recon_sum(rows, cols) + patch * patch_weight;
        weight_sum(rows, cols) = weight_sum(rows, cols) + patch_weight;
    end
end
assert(patch_idx == num_patches);

deblurred_image = recon_sum ./ weight_sum;

% Clip to [0, 1]; comparisons leave NaN entries untouched.
deblurred_image(deblurred_image > 1) = 1;
deblurred_image(deblurred_image < 0) = 0;

end


function patches = collect_patches(img, patch_size, row_starts, col_starts)
% COLLECT_PATCHES  Stack square patches of IMG as columns.
%   Patches are PATCH_SIZE x PATCH_SIZE, taken with their top-left corner
%   at every (row, column) pair from ROW_STARTS x COL_STARTS, visiting all
%   columns for the first row start, then the next row start, and so on.
%   Each patch is flattened in column-major order. Returns [] when there
%   are no positions.

columns = cell(1, numel(row_starts) * numel(col_starts));
k = 0;
for r = row_starts
    for c = col_starts
        k = k + 1;
        patch = img(r:r+patch_size-1, c:c+patch_size-1);
        columns{k} = patch(:);
    end
end
patches = [columns{:}];

end